import torch

from nn_conv import output

# A 2x2 tensor of scores. Presumably rows are samples and columns are
# per-class scores (inferred from the argmax/target comparison below).
outputs = torch.tensor([
    [0.1, 0.2],
    [0.3, 0.4],
])

# argmax(1) reduces along dim 1, i.e. across each row (one index per row);
# argmax(0) would reduce along dim 0, i.e. down each column.
preds = outputs.argmax(1)
print(preds)

targets = torch.tensor([0, 1])

# Element-wise equality gives a bool tensor; summing it counts the matches.
num_correct = (preds == targets).sum()
print(num_correct)

